'''
Author: SlytherinGe
LastEditTime: 2021-11-17 19:27:20
'''
if __name__ == '__main__':
    import voc_label_utility as VL
else:
    import dataset_utilty.voc_label_utility as VL
import cv2
import os
import numpy as np

def draw_one_bbox(canvas, pos, cls_name, cls_property, color, is_rotation=False):
    '''
    Draw the outline of one bounding box on canvas (in place).

    canvas: image array, modified in place by cv2.drawContours
    pos: mapping holding the box coordinates (numbers or numeric strings):
        - is_rotation=False: 'xmin', 'ymin', 'xmax', 'ymax'
        - is_rotation=True: four corner points 'x1'/'y1' ... 'x4'/'y4'
    cls_name, cls_property: currently unused (label drawing is disabled),
        kept for interface compatibility
    color: colour passed to cv2.drawContours
    is_rotation: treat pos as a 4-point (possibly rotated) box
    Rotated boxes whose fitted rectangle has zero width or height are skipped.
    '''
    if is_rotation:
        # float() (rather than int()) also accepts decimal strings such as "12.5"
        corners = np.array(
            [[float(pos['x' + k]), float(pos['y' + k])] for k in '1234'],
            dtype=np.float32)
        # Fit the minimum-area rotated rectangle so the corner order stored in
        # the annotation does not matter.
        rect = cv2.minAreaRect(corners)
        w, h = rect[1]
        if w == 0 or h == 0:
            return
        # boxPoints returns the four vertices in consecutive order. The drawn
        # shape is the same whichever angle convention the installed OpenCV
        # reports, so no angle normalisation is needed just to draw it.
        pts = cv2.boxPoints(rect)
    else:
        xmin, ymin = float(pos['xmin']), float(pos['ymin'])
        xmax, ymax = float(pos['xmax']), float(pos['ymax'])
        # Corners must go around the rectangle (TL, TR, BR, BL); listing them
        # as TL, TR, BL, BR would draw a crossed "bow-tie".
        pts = np.array([[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]])

    # drawContours only accepts int32 points. astype truncates toward zero like
    # the former np.int0, an alias that was removed in NumPy 2.0.
    cv2.drawContours(canvas, [pts.astype(np.int32)], 0, color, thickness=2)


def draw_bboxes_on_one_image(img_path, anno_path, save_path, color_dict=None, display=False, is_rotation=False):
    '''
    Draw every object box of one annotation file onto its image and save it.

    img_path: image to read with cv2.imread
    anno_path: annotation file, parsed with VL.voc_label_preprocess
    save_path: where the drawn image is written with cv2.imwrite
    color_dict: optional {class name: colour}; classes missing from it (and all
        classes when it is None) are drawn in the default green (0, 255, 0)
    display: also show the result with matplotlib
    is_rotation: forwarded to draw_one_bbox
    Raises FileNotFoundError if the image cannot be read and OSError if the
    result cannot be written.
    '''
    default_color = (0, 255, 0)

    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising
        raise FileNotFoundError('cannot read image: ' + img_path)
    _, anno_obj = VL.voc_label_preprocess(anno_path)

    for obj in anno_obj:
        dict_obj = VL.voc_object_xml_element_resolver(obj)
        name = dict_obj['name']
        if color_dict is None:
            color = default_color
        else:
            color = color_dict.get(name, default_color)
        draw_one_bbox(img, dict_obj['bndbox'], name, 1.00, color, is_rotation)

    if display:
        import matplotlib.pyplot as plt
        plt.figure()
        plt.imshow(img[:, :, [2, 1, 0]])  # OpenCV images are BGR, matplotlib expects RGB
        plt.axis('off')
        plt.show()
    # cv2.imwrite also reports failure (e.g. missing directory) only via its return value
    if not cv2.imwrite(save_path, img):
        raise OSError('cannot write image: ' + save_path)

def draw_bboxes_on_images(img_files, anno_root, save_root, color_dict=None, display=False, is_rotation=False):
    '''
    Draw the annotated bboxes on a list of images and save the results.

    img_files: a list of image file paths to draw bboxes on
    anno_root: directory holding the annotation files; the annotation of
        "<dir>/<stem>.<ext>" is read from "<anno_root>/<stem>.xml"
    save_root: directory the drawn images are written to, keeping each input
        file name; created if it does not exist
    color_dict: a dict to specify bbox colors for each class, all boxes are green by default
    display: whether to show each result image, default is False
    is_rotation: whether the annotations are 4-point rotated boxes
    '''
    if save_root:
        os.makedirs(save_root, exist_ok=True)

    for img_path in img_files:
        base_name = os.path.basename(img_path)
        # splitext only strips the last extension ("a.b.jpg" -> "a.b") and
        # keeps names without any extension intact
        stem = os.path.splitext(base_name)[0]
        anno_path = os.path.join(anno_root, stem + '.xml')
        draw_bboxes_on_one_image(img_path, anno_path, os.path.join(save_root, base_name),
                                 color_dict, display, is_rotation)
    

# '''
# An example that draws boxes on a folder of images and saves them to another folder, using multiple processes (os.fork)
# '''
# if __name__ == '__main__':

#     IMG_ROOT = "/media/gejunyao/Disk1/Datasets/SSDD/VOC2012/JPEGImages/"
#     ANNO_ROOT = '/media/gejunyao/Disk1/Datasets/SSDD/VOC2012/Annotations/'
#     OUT_ROOT = '/media/gejunyao/Disk/Gejunyao/exp_results/visualization/results/ssdd_gt/'

#     img_files = os.listdir(IMG_ROOT)
#     for i in range(len(img_files)):
#         img_files[i] = os.path.join(IMG_ROOT, img_files[i])

#     color = dict(ship=(255, 0, 0))

#     if not os.path.exists(OUT_ROOT):
#         os.mkdir(OUT_ROOT)

#     for i in range(4):
#         pid = os.fork()
#         if pid == 0:
#             # child
#             img_files = img_files[:int(len(img_files)/2)]
#         elif pid > 0:
#             # parent
#             img_files = img_files[int(len(img_files)/2):]
#         else:
#             print('fork failed！')    

#     draw_bboxes_on_images(img_files, ANNO_ROOT, OUT_ROOT, color)

'''
Draw the annotated bboxes on every image in a folder and write the results to an output folder.
'''
def draw_bboxes_on_dataset(img_root, anno_root, out_root, is_rotation=False,
                           num_workers=16, cls_name='ship', color=(255, 0, 0)):
    '''
    Draw the annotated bboxes on every image in img_root and save them to out_root.

    img_root: directory of input images
    anno_root: directory of annotations; "<stem>.<ext>" uses "<anno_root>/<stem>.xml"
    out_root: output directory (created together with missing parents);
        output files keep the input file names
    is_rotation: whether the annotations are 4-point rotated boxes
    num_workers: number of worker processes (default 16, the same parallelism
        as four rounds of os.fork); <= 1 processes everything in this process
    cls_name: class name passed to draw_one_bbox for every box
    color: colour of every box

    Returns only after every image has been processed, and only in the calling
    process; worker processes never run the caller's code. Uses the 'fork'
    start method, so, like os.fork, it is Unix-only.
    '''
    import functools
    import multiprocessing

    os.makedirs(out_root, exist_ok=True)
    img_files = sorted(os.listdir(img_root))
    work = functools.partial(_draw_one_dataset_image, img_root, anno_root, out_root,
                             is_rotation, cls_name, color)

    workers = min(num_workers, len(img_files))
    if workers <= 1:
        for img_f in img_files:
            work(img_f)
        return

    # 'fork' keeps the Unix-only semantics of os.fork and avoids re-importing
    # this module (and its __name__-dependent VL import) in the workers.
    with multiprocessing.get_context('fork').Pool(workers) as pool:
        pool.map(work, img_files)


def _draw_one_dataset_image(img_root, anno_root, out_root, is_rotation, cls_name, color, img_f):
    '''
    Worker for draw_bboxes_on_dataset: draw the boxes of one image and save it.

    Problems with a single file are reported and skipped so that one bad file
    does not abort the whole dataset.
    '''
    img_path = os.path.join(img_root, img_f)
    anno_path = os.path.join(anno_root, os.path.splitext(img_f)[0] + '.xml')
    if not os.path.exists(anno_path):
        print('no annotation for ' + img_path)
        return
    img = cv2.imread(img_path)
    if img is None:
        # not an image, or unreadable: cv2.imread returns None instead of raising
        print('cannot read image ' + img_path)
        return
    _, anno_obj = VL.voc_label_preprocess(anno_path)
    for obj in anno_obj:
        dict_obj = VL.voc_object_xml_element_resolver(obj)
        draw_one_bbox(img, dict_obj['bndbox'], cls_name, 1.00, color, is_rotation)
    out_path = os.path.join(out_root, img_f)
    if not cv2.imwrite(out_path, img):
        print('cannot write ' + out_path)

if __name__ == '__main__':
    # Script entry: render the rotated (4-point) ground-truth boxes of SSDD.
    IMG_ROOT = "/media/gejunyao/Disk1/Datasets/SSDD/VOC2012/JPEGImages/"
    ANNO_ROOT = '/media/gejunyao/Disk1/Datasets/SSDD/VOC2012/annotations_r/'
    OUT_ROOT = '/media/gejunyao/Disk/Gejunyao/exp_results/visualization/results/ssdd/ssdd_gt_r2'
    draw_bboxes_on_dataset(img_root=IMG_ROOT,
                           anno_root=ANNO_ROOT,
                           out_root=OUT_ROOT,
                           is_rotation=True)



'''
Place same-named images from different folders side by side for comparison.
'''
# if __name__ == '__main__':
#     import os    
#     import numpy as np
#     AOI_ROOT = '/media/gejunyao/Disk/Gejunyao/develop/temp/results/faster_rcnn_aoi/faster_rcnn_show_ships/with_gt/'
#     SHIP_ROOT = '/media/gejunyao/Disk/Gejunyao/develop/temp/results/faster_rcnn_bigbox50/with_gt/'
#     OUT_ROOT = '/media/gejunyao/Disk/Gejunyao/develop/temp/results/faster_rcnn_bigbox50/ship_and_biggerbox/'
#     aoi_files = os.listdir(AOI_ROOT)
#     if not os.path.exists(OUT_ROOT):
#         print("output folder does not exsit, creating...")
#         os.mkdir(OUT_ROOT)

#     pid = os.fork()
#     if pid == 0:
#         # child
#         aoi_files = aoi_files[:int(len(aoi_files)/2)]
#     elif pid > 0:
#         # parent
#         aoi_files = aoi_files[int(len(aoi_files)/2):]
#     else:
#         print('fork failed！')
#     canvas = np.zeros((1200, 2400, 3), np.uint8)
#     canvas.fill(255)
#     cv2.putText(canvas, 'ship results', (280,60), cv2.FONT_HERSHEY_COMPLEX_SMALL, 2, (0,0,0))
#     cv2.putText(canvas, 'bigger box results', (1500,60), cv2.FONT_HERSHEY_COMPLEX_SMALL, 2, (0,0,0))
#     for f in aoi_files:
#         aoi_img = cv2.imread(os.path.join(AOI_ROOT, f))
#         ship_img = cv2.imread(os.path.join(SHIP_ROOT, f))
#         canvas[100:1100,100:1100] = aoi_img
#         canvas[100:1100,1300:2300] = ship_img        
#         cv2.imwrite(os.path.join(OUT_ROOT, f), canvas)

# if __name__ == '__main__':

#     import matplotlib.pyplot as plt
#     import os
#     IMG_ROOT = "/media/gejunyao/Disk1/Customized Datasets/VEDAI_VOC/Vehicules512_co/"
#     ANNO_ROOT = '/media/gejunyao/Disk1/Customized Datasets/VEDAI_VOC/VOC_Anno_512/'
#     OUT_ROOT = '/media/gejunyao/Disk1/Customized Datasets/VEDAI_VOC/512_with_gt'
#     img_files = os.listdir(IMG_ROOT)

#     if not os.path.exists(OUT_ROOT):
#         os.mkdir(OUT_ROOT)

#     for i in range(4):
#         pid = os.fork()
#         if pid == 0:
#             # child
#             img_files = img_files[:int(len(img_files)/2)]
#         elif pid > 0:
#             # parent
#             img_files = img_files[int(len(img_files)/2):]
#         else:
#             print('fork failed！')

#     for img_f in img_files:
#         name_l = img_f.split('_')[0]
#         name = "1024_"+name_l
#         img_path = os.path.join(IMG_ROOT, img_f)
#         anno_path = os.path.join(ANNO_ROOT, name+'.xml')
#         if not os.path.exists(anno_path):
#             print(img_path)
#             continue
#         img = cv2.imread(img_path)
#         anno_info, anno_obj = VL.voc_label_preprocess(anno_path)
#         for obj in anno_obj:
#             dict_obj = VL.voc_object_xml_element_resolver(obj)
#             draw_one_bbox(img, dict_obj['bndbox'], dict_obj['name'], 1.00, (255, 0, 0))
#         cv2.imwrite(os.path.join(OUT_ROOT, img_f), img)

# if __name__ == '__main__':
#     import matplotlib.pyplot as plt
#     img = cv2.imread('/media/gejunyao/Disk1/Customized Datasets/VEDAI_VOC/1024_with_gt/00000006_ir.png')
#     plt.figure()
#     plt.imshow(img)
#     plt.show()